package com.shiku.imserver.common.http;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
import org.tio.core.Tio;
import org.tio.core.exception.AioDecodeException;
import org.tio.http.common.HttpConfig;
import org.tio.http.common.HttpConst;
import org.tio.http.common.Method;
import org.tio.http.common.RequestLine;
import org.tio.http.common.utils.HttpParseUtils;
import org.tio.utils.hutool.StrUtil;

public class HttpRequestDecoder {
  /**
   * Names the three parts of an HTTP request message: the request line, the
   * header section and the body. The enum is declared here but is not
   * referenced by any method visible in this class.
   */
  public enum Step {
    firstline, header, body;
  }
  
  /** Logger shared by the static decoding helpers of this class. */
  private static final Logger log = LoggerFactory.getLogger(HttpRequestDecoder.class);

  /** Upper bound, in bytes, for the accumulated length of all header lines of one request. */
  public static final int MAX_LENGTH_OF_HEADER = 20480;

  /** Upper bound, in bytes, for a single header line. */
  public static final int MAX_LENGTH_OF_HEADERLINE = 20480;

  /** Upper bound, in bytes, for the request line (method, target and protocol version). */
  public static final int MAX_LENGTH_OF_REQUESTLINE = 20480;
  
  /**
   * Decodes one HTTP request from {@code buffer}.
   *
   * <p>The request line and the header section are parsed first. If either is
   * incomplete, or if fewer bytes than header length plus {@code Content-Length}
   * are readable, the method returns {@code null}; in the latter case the
   * required total is recorded on the channel via
   * {@code setPacketNeededLength}.
   *
   * @param buffer         buffer positioned at the first byte of the request
   * @param limit          not used by this method
   * @param position       buffer position at which this request starts; used to
   *                       compute the length of request line plus headers
   * @param readableLength number of bytes currently available for this request
   * @param channelContext channel the bytes were read from
   * @param httpConfig     server configuration (body size limit, host check,
   *                       header-string flag)
   * @return the decoded request, or {@code null} when more bytes are needed
   * @throws AioDecodeException if the request line or headers are too long, the
   *                            {@code Content-Length} header is not a
   *                            non-negative integer or exceeds the configured
   *                            maximum, the host header is required but missing,
   *                            or the client IP is blacklisted
   */
  public static HttpRequest decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext, HttpConfig httpConfig) throws AioDecodeException {
    Map<String, String> headers = new HashMap<>();
    StringBuilder headerSb = null;
    boolean appendRequestHeaderString = httpConfig.isAppendRequestHeaderString();
    if (appendRequestHeaderString)
      headerSb = new StringBuilder(512);

    RequestLine firstLine = parseRequestLine(buffer, channelContext);
    if (firstLine == null)
      return null;

    HttpRequest httpRequest = new HttpRequest(channelContext.getClientNode());
    httpRequest.setRequestLine(firstLine);
    httpRequest.setChannelContext(channelContext);
    httpRequest.setHttpConfig(httpConfig);

    boolean headerCompleted = parseHeaderLine(buffer, headers, 0, httpConfig);
    if (!headerCompleted)
      return null;

    int contentLength = 0;
    String contentLengthStr = headers.get("content-length");
    if (!StrUtil.isBlank(contentLengthStr)) {
      // The header value comes straight from the peer: report malformed or
      // negative values as a decode error instead of letting an unchecked
      // NumberFormatException / NegativeArraySizeException escape.
      try {
        contentLength = Integer.parseInt(contentLengthStr.trim());
      } catch (NumberFormatException e) {
        throw new AioDecodeException("content-length is not a valid integer[" + contentLengthStr + "]");
      }
      if (contentLength < 0)
        throw new AioDecodeException("content-length must not be negative[" + contentLength + "]");
      if (contentLength > httpConfig.getMaxLengthOfPostBody())
        throw new AioDecodeException("post body length is too big[" + contentLength + "], max length is " + httpConfig.getMaxLengthOfPostBody() + " byte");
    }

    // Everything consumed so far is request line + headers.
    int headerLength = buffer.position() - position;
    int allNeedLength = headerLength + contentLength;
    if (readableLength < allNeedLength) {
      channelContext.setPacketNeededLength(Integer.valueOf(allNeedLength));
      return null;
    }

    if (httpConfig.checkHost && !headers.containsKey("host"))
      throw new AioDecodeException("there is no host header");

    // NOTE(review): nothing in this class appends to headerSb, so the header
    // string is empty in both branches.
    if (appendRequestHeaderString) {
      httpRequest.setHeaderString(headerSb.toString());
    } else {
      httpRequest.setHeaderString("");
    }
    httpRequest.setHeaders(headers);

    if (Tio.IpBlacklist.isInBlacklist(channelContext.groupContext, httpRequest.getClientIp()))
      throw new AioDecodeException("[" + httpRequest.getClientIp() + "] in black list");

    httpRequest.setContentLength(contentLength);
    String connection = headers.get("connection");
    if (connection != null)
      httpRequest.setConnection(connection.toLowerCase());

    if (StrUtil.isNotBlank(firstLine.queryString))
      decodeParams(httpRequest.getParams(), firstLine.queryString, httpRequest.getCharset(), channelContext);

    if (contentLength != 0) {
      byte[] bodyBytes = new byte[contentLength];
      buffer.get(bodyBytes);
      httpRequest.setBody(bodyBytes);
      parseBody(httpRequest, firstLine, bodyBytes, channelContext, httpConfig);
    }
    return httpRequest;
  }
  
  /**
   * Parses a {@code key=value&key=value} string and adds the pairs to
   * {@code params}. Repeated keys accumulate: each further value is appended
   * to the array already stored under that key.
   *
   * <p>A pair is split at its <em>first</em> {@code '='} only, so a value that
   * itself contains {@code '='} (for example base64 padding) is kept intact
   * rather than being dropped or truncated. Pairs without {@code '='} or with
   * an empty value are ignored. Keys are stored as received; only the value is
   * URL-decoded.
   *
   * @param params         target map, modified in place
   * @param queryString    raw query string or url-encoded body; blank input is
   *                       a no-op
   * @param charset        charset name passed to {@link URLDecoder#decode(String, String)}
   * @param channelContext used only to label the log entry when {@code charset}
   *                       is not supported
   */
  public static void decodeParams(Map<String, Object[]> params, String queryString, String charset, ChannelContext channelContext) {
    if (StrUtil.isBlank(queryString))
      return;
    for (String keyvalue : queryString.split("&")) {
      int separator = keyvalue.indexOf('=');
      if (separator < 0)
        continue;
      String key = keyvalue.substring(0, separator);
      String rawValue = keyvalue.substring(separator + 1);
      if (rawValue.isEmpty())
        continue;

      String value = null;
      try {
        value = URLDecoder.decode(rawValue, charset);
      } catch (UnsupportedEncodingException e) {
        // The pair is still recorded, with a null value.
        log.error(channelContext.toString(), e);
      }

      Object[] existValue = params.get(key);
      if (existValue != null) {
        String[] newExistValue = new String[existValue.length + 1];
        System.arraycopy(existValue, 0, newExistValue, 0, existValue.length);
        newExistValue[newExistValue.length - 1] = value;
        params.put(key, newExistValue);
      } else {
        params.put(key, new String[] { value });
      }
    }
  }
  
  /**
   * Interprets the request body according to the body format derived from the
   * request headers.
   *
   * <p>Multipart bodies are handed to {@link HttpMultiBodyDecoder}. Every other
   * body is converted to a string using the request charset and stored on the
   * request; url-encoded bodies are additionally parsed into request
   * parameters.
   *
   * @param httpRequest    request being populated
   * @param firstLine      parsed request line
   * @param bodyBytes      raw body bytes
   * @param channelContext channel the request arrived on (used for logging and
   *                       passed on to the multipart decoder)
   * @param httpConfig     server configuration, passed on to the multipart decoder
   * @throws AioDecodeException declared for the multipart decoding path
   */
  private static void parseBody(HttpRequest httpRequest, RequestLine firstLine, byte[] bodyBytes, ChannelContext channelContext, HttpConfig httpConfig) throws AioDecodeException {
    parseBodyFormat(httpRequest, httpRequest.getHeaders());
    HttpConst.RequestBodyFormat bodyFormat = httpRequest.getBodyFormat();
    httpRequest.setBody(bodyBytes);
    boolean hasBody = bodyBytes != null && bodyBytes.length > 0;

    switch (bodyFormat) {
      case MULTIPART: {
        // The raw multipart payload is only dumped at debug level.
        if (log.isInfoEnabled() && hasBody && log.isDebugEnabled()) {
          try {
            String multipartText = new String(bodyBytes, httpRequest.getCharset());
            log.debug("{} multipart body value\r\n{}", channelContext, multipartText);
          } catch (UnsupportedEncodingException e) {
            log.error(channelContext.toString(), e);
          }
        }
        String multipartContentType = httpRequest.getHeader("content-type");
        String boundary = HttpParseUtils.getSubAttribute(multipartContentType, "boundary");
        log.debug("{}, initboundary:{}", channelContext, boundary);
        HttpMultiBodyDecoder.decode(httpRequest, firstLine, bodyBytes, boundary, channelContext, httpConfig);
        return;
      }
      default:
        break;
    }

    String bodyString = null;
    if (hasBody) {
      try {
        bodyString = new String(bodyBytes, httpRequest.getCharset());
        httpRequest.setBodyString(bodyString);
        if (log.isInfoEnabled()) {
          log.info("{} body value\r\n{}", channelContext, bodyString);
        }
      } catch (UnsupportedEncodingException e) {
        log.error(channelContext.toString(), e);
      }
    }

    if (HttpConst.RequestBodyFormat.URLENCODED == bodyFormat) {
      parseUrlencoded(httpRequest, firstLine, bodyBytes, bodyString, channelContext);
    }
  }
  
  /**
   * Derives the body format and charset of a request from its
   * {@code content-type} header and stores both on {@code httpRequest}.
   *
   * <ul>
   *   <li>{@code text/plain...} selects {@code TEXT}</li>
   *   <li>{@code multipart/form-data...} selects {@code MULTIPART}</li>
   *   <li>anything else, including a missing header, selects {@code URLENCODED}</li>
   * </ul>
   *
   * <p>When the header is present and not blank, the charset is taken from its
   * {@code charset} attribute, falling back to {@code utf-8}. When the header
   * is absent or blank the request's charset is left untouched.
   *
   * @param httpRequest request to update
   * @param headers     request headers keyed by lower-case name
   */
  public static void parseBodyFormat(HttpRequest httpRequest, Map<String, String> headers) {
    String contentType = headers.get("content-type");
    if (contentType == null) {
      // A body without a Content-Type header used to fail with a
      // NullPointerException here; treat it like any other unrecognised type.
      httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.URLENCODED);
      return;
    }

    String lowerContentType = contentType.toLowerCase();
    if (lowerContentType.startsWith("text/plain")) {
      httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.TEXT);
    } else if (lowerContentType.startsWith("multipart/form-data")) {
      httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.MULTIPART);
    } else {
      httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.URLENCODED);
    }

    if (StrUtil.isNotBlank(lowerContentType)) {
      String charset = HttpParseUtils.getSubAttribute(lowerContentType, "charset");
      if (StrUtil.isNotBlank(charset)) {
        httpRequest.setCharset(charset);
      } else {
        httpRequest.setCharset("utf-8");
      }
    }
  }
  
  /**
   * Parses header lines from {@code buffer} into {@code headers} until the
   * blank line that ends the header section.
   *
   * <p>Header names are stored lower-cased. Spaces directly after the colon are
   * skipped and trailing whitespace of the value is trimmed. A line without a
   * colon is stored as a header with an empty value. Lines may end in
   * {@code CRLF} or a bare {@code LF}.
   *
   * <p>On a {@code true} return the buffer is positioned exactly behind the
   * terminating blank line, so the next byte read is the first body byte.
   * Buffers that are not array-backed are delegated to the non-array variant.
   *
   * @param buffer                  buffer positioned at the start of a header line
   * @param headers                 receives the parsed headers
   * @param hasReceivedHeaderLength header bytes already consumed by earlier lines
   * @param httpConfig              server configuration (only passed through)
   * @return {@code true} once the blank line was consumed, {@code false} if the
   *         buffer ran out before the header section was complete
   * @throws AioDecodeException if a single line or the accumulated header
   *                            section exceeds its size limit
   */
  public static boolean parseHeaderLine(ByteBuffer buffer, Map<String, String> headers, int hasReceivedHeaderLength, HttpConfig httpConfig) throws AioDecodeException {
    if (!buffer.hasArray())
      return parseHeaderLine2(buffer, headers, hasReceivedHeaderLength, httpConfig);

    byte[] allbs = buffer.array();
    // Buffer position p lives at array index p + arrayOffset (non-zero for slices).
    int arrayOffset = buffer.arrayOffset();
    int headerLength = hasReceivedHeaderLength;

    // One iteration per header line; a loop instead of recursion keeps the
    // stack depth independent of the number of headers sent by the peer.
    while (true) {
      int initPosition = buffer.position();
      int remaining = buffer.remaining();
      if (remaining == 0)
        return false;

      // Blank line => end of headers. Peek with absolute reads so that only the
      // terminator itself is consumed and never the first body byte.
      byte first = buffer.get(initPosition);
      if (first == 10) {
        buffer.position(initPosition + 1);
        return true;
      }
      if (first == 13 && remaining > 1 && buffer.get(initPosition + 1) == 10) {
        buffer.position(initPosition + 2);
        return true;
      }

      int lastPosition = initPosition;
      String name = null;
      boolean hasValue = false;
      boolean lineCompleted = false;
      while (buffer.hasRemaining()) {
        byte b = buffer.get();
        if (name == null) {
          if (b == 58) {
            // ':' ends the header name.
            name = new String(allbs, arrayOffset + lastPosition, buffer.position() - lastPosition - 1);
            lastPosition = buffer.position();
          } else if (b == 10) {
            // Line without ':' -> header with an empty value.
            int len = buffer.position() - lastPosition - 1;
            if (buffer.get(buffer.position() - 2) == 13)
              len--;
            name = new String(allbs, arrayOffset + lastPosition, len);
            lastPosition = buffer.position();
            headers.put(name.toLowerCase(), "");
            lineCompleted = true;
            break;
          }
          continue;
        }
        if (b == 10) {
          int len = buffer.position() - lastPosition - 1;
          if (buffer.get(buffer.position() - 2) == 13)
            len--;
          String value = new String(allbs, arrayOffset + lastPosition, len);
          lastPosition = buffer.position();
          headers.put(name.toLowerCase(), StrUtil.trimEnd(value));
          lineCompleted = true;
          break;
        }
        if (!hasValue && b == 32) {
          // Skip spaces between ':' and the first value character.
          lastPosition = buffer.position();
          continue;
        }
        hasValue = true;
      }

      int lineLength = buffer.position() - initPosition;
      if (lineLength > MAX_LENGTH_OF_HEADERLINE)
        throw new AioDecodeException("header line is too long, max length of header line is " + MAX_LENGTH_OF_HEADERLINE);
      if (!lineCompleted)
        return false;
      headerLength += lineLength;
      if (headerLength > MAX_LENGTH_OF_HEADER)
        throw new AioDecodeException("header is too long, max length of header is " + MAX_LENGTH_OF_HEADER);
    }
  }
  
  /**
   * Variant of {@link #parseHeaderLine} for buffers that are not backed by an
   * accessible array. It follows the same rules: names are lower-cased, spaces
   * after the colon are skipped, trailing whitespace of the value is trimmed, a
   * line without a colon is stored with an empty value, and both the per-line
   * and the accumulated header size limits are enforced.
   *
   * <p>On a {@code true} return the buffer is positioned exactly behind the
   * blank line that ends the header section.
   *
   * @param buffer       buffer positioned at the start of a header line
   * @param headers      receives the parsed headers
   * @param headerLength header bytes already consumed by earlier lines
   * @param httpConfig   server configuration (unused here)
   * @return {@code true} once the blank line was consumed, {@code false} if the
   *         buffer ran out before the header section was complete
   * @throws AioDecodeException if a single line or the accumulated header
   *                            section exceeds its size limit
   */
  private static boolean parseHeaderLine2(ByteBuffer buffer, Map<String, String> headers, int headerLength, HttpConfig httpConfig) throws AioDecodeException {
    int totalHeaderLength = headerLength;

    // One iteration per header line (no recursion, so stack depth does not
    // grow with the number of headers).
    while (true) {
      int initPosition = buffer.position();
      int remaining = buffer.remaining();
      if (remaining == 0)
        return false;

      // Blank line => end of headers; consume only the terminator itself.
      byte first = buffer.get(initPosition);
      if (first == 10) {
        buffer.position(initPosition + 1);
        return true;
      }
      if (first == 13 && remaining > 1 && buffer.get(initPosition + 1) == 10) {
        buffer.position(initPosition + 2);
        return true;
      }

      int lastPosition = initPosition;
      String name = null;
      boolean hasValue = false;
      boolean lineCompleted = false;
      while (buffer.hasRemaining()) {
        byte b = buffer.get();
        int nowPosition = buffer.position();
        if (name == null) {
          if (b == 58) {
            // ':' ends the header name.
            name = readHeaderText(buffer, lastPosition, nowPosition - lastPosition - 1);
            lastPosition = nowPosition;
          } else if (b == 10) {
            // Line without ':' -> header with an empty value, matching the
            // array-backed parser.
            int len = nowPosition - lastPosition - 1;
            if (buffer.get(nowPosition - 2) == 13)
              len--;
            name = readHeaderText(buffer, lastPosition, len);
            lastPosition = nowPosition;
            headers.put(name.toLowerCase(), "");
            lineCompleted = true;
            break;
          }
          continue;
        }
        if (b == 10) {
          int len = nowPosition - lastPosition - 1;
          if (buffer.get(nowPosition - 2) == 13)
            len--;
          String value = readHeaderText(buffer, lastPosition, len);
          lastPosition = nowPosition;
          headers.put(name.toLowerCase(), StrUtil.trimEnd(value));
          lineCompleted = true;
          break;
        }
        if (!hasValue && b == 32) {
          // Skip spaces between ':' and the first value character.
          lastPosition = nowPosition;
          continue;
        }
        hasValue = true;
      }

      int lineLength = buffer.position() - initPosition;
      if (lineLength > MAX_LENGTH_OF_HEADERLINE)
        throw new AioDecodeException("header line is too long");
      if (!lineCompleted)
        return false;
      totalHeaderLength += lineLength;
      if (totalHeaderLength > MAX_LENGTH_OF_HEADER)
        throw new AioDecodeException("header is too long");
    }
  }

  /** Copies {@code length} bytes starting at {@code from} into a string without moving the buffer's position. */
  private static String readHeaderText(ByteBuffer buffer, int from, int length) {
    byte[] bs = new byte[length];
    ByteBuffer view = buffer.duplicate();
    view.position(from);
    view.get(bs);
    return new String(bs);
  }
  
  /**
   * Parses the HTTP request line ({@code METHOD path[?query] PROTOCOL/version})
   * from {@code buffer}.
   *
   * <p>The line may end in {@code CRLF} or a bare {@code LF}. When the path is
   * followed directly by a space the query string is set to the empty string.
   * Buffers that are not array-backed are delegated to the non-array variant.
   *
   * @param buffer         buffer positioned at the first byte of the request line
   * @param channelContext channel the bytes came from (only passed through)
   * @return the parsed request line, or {@code null} if the buffer ended before
   *         the line was complete
   * @throws AioDecodeException if the buffer ended without a complete line
   *                            after more than the maximum request line length
   *                            had been read
   */
  public static RequestLine parseRequestLine(ByteBuffer buffer, ChannelContext channelContext) throws AioDecodeException {
    if (!buffer.hasArray())
      return parseRequestLine2(buffer, channelContext);

    byte[] allbs = buffer.array();
    // Buffer position p lives at array index p + arrayOffset (non-zero for
    // slices), so every array access below adds it.
    int arrayOffset = buffer.arrayOffset();
    int initPosition = buffer.position();
    int lastPosition = initPosition;
    String methodStr = null;
    String pathStr = null;
    String queryStr = null;
    String protocol = null;

    while (buffer.hasRemaining()) {
      byte b = buffer.get();
      int tokenLength = buffer.position() - lastPosition - 1;
      if (methodStr == null) {
        if (b == 32) {
          methodStr = new String(allbs, arrayOffset + lastPosition, tokenLength);
          lastPosition = buffer.position();
        }
      } else if (pathStr == null) {
        // The path ends at a space (no query) or at '?'.
        if (b == 32 || b == 63) {
          pathStr = new String(allbs, arrayOffset + lastPosition, tokenLength);
          lastPosition = buffer.position();
          if (b == 32)
            queryStr = "";
        }
      } else if (queryStr == null) {
        if (b == 32) {
          queryStr = new String(allbs, arrayOffset + lastPosition, tokenLength);
          lastPosition = buffer.position();
        }
      } else if (protocol == null) {
        // '/' separates the protocol name from its version.
        if (b == 47) {
          protocol = new String(allbs, arrayOffset + lastPosition, tokenLength);
          lastPosition = buffer.position();
        }
      } else if (b == 10) {
        // Drop the CR of a CRLF line ending from the version token.
        if (buffer.get(buffer.position() - 2) == 13)
          tokenLength--;
        String version = new String(allbs, arrayOffset + lastPosition, tokenLength);

        RequestLine requestLine = new RequestLine();
        requestLine.setMethod(Method.from(methodStr));
        requestLine.setPath(pathStr);
        requestLine.setInitPath(pathStr);
        requestLine.setQueryString(queryStr);
        requestLine.setProtocol(protocol);
        requestLine.setVersion(version);
        return requestLine;
      }
    }

    if (buffer.position() - initPosition > MAX_LENGTH_OF_REQUESTLINE)
      throw new AioDecodeException("request line is too long");
    return null;
  }
  
  /**
   * Request-line parser for buffers without an accessible backing array.
   * Tokens are copied out of the buffer by temporarily repositioning it.
   *
   * @param buffer         buffer positioned at the first byte of the request line
   * @param channelContext channel the bytes came from (unused here)
   * @return the parsed request line, or {@code null} if the buffer ended before
   *         the line was complete
   * @throws AioDecodeException if the buffer ended without a complete line
   *                            after more than 20480 bytes had been read
   */
  private static RequestLine parseRequestLine2(ByteBuffer buffer, ChannelContext channelContext) throws AioDecodeException {
    final int startPosition = buffer.position();
    int tokenStart = startPosition;
    String methodStr = null;
    String pathStr = null;
    String queryStr = null;
    String protocol = null;

    while (buffer.hasRemaining()) {
      final byte current = buffer.get();
      final int cursor = buffer.position();

      if (methodStr == null) {
        if (current == ' ') {
          methodStr = readRequestLineToken(buffer, tokenStart, cursor - tokenStart - 1, cursor);
          tokenStart = cursor;
        }
      } else if (pathStr == null) {
        // The path ends at a space (no query) or at '?'.
        if (current == ' ' || current == '?') {
          pathStr = readRequestLineToken(buffer, tokenStart, cursor - tokenStart - 1, cursor);
          tokenStart = cursor;
          if (current == ' ') {
            queryStr = "";
          }
        }
      } else if (queryStr == null) {
        if (current == ' ') {
          queryStr = readRequestLineToken(buffer, tokenStart, cursor - tokenStart - 1, cursor);
          tokenStart = cursor;
        }
      } else if (protocol == null) {
        // '/' separates the protocol name from its version.
        if (current == '/') {
          protocol = readRequestLineToken(buffer, tokenStart, cursor - tokenStart - 1, cursor);
          tokenStart = cursor;
        }
      } else if (current == '\n') {
        // Exclude the CR of a CRLF line ending from the version token.
        final boolean endsWithCrLf = buffer.get(cursor - 2) == '\r';
        final int versionLength = cursor - tokenStart - (endsWithCrLf ? 2 : 1);
        final String version = readRequestLineToken(buffer, tokenStart, versionLength, cursor);

        final RequestLine parsed = new RequestLine();
        parsed.setMethod(Method.from(methodStr));
        parsed.setPath(pathStr);
        parsed.setInitPath(pathStr);
        parsed.setQueryString(queryStr);
        parsed.setProtocol(protocol);
        parsed.setVersion(version);
        return parsed;
      }
    }

    final int consumed = buffer.position() - startPosition;
    if (consumed > 20480) {
      throw new AioDecodeException("request line is too long");
    }
    return null;
  }

  /**
   * Reads {@code length} bytes starting at {@code from} as a string and then
   * moves the buffer to {@code resumeAt} so scanning can continue.
   */
  private static String readRequestLineToken(ByteBuffer buffer, int from, int length, int resumeAt) {
    byte[] tokenBytes = new byte[length];
    buffer.position(from);
    buffer.get(tokenBytes);
    String token = new String(tokenBytes);
    buffer.position(resumeAt);
    return token;
  }
  
  /**
   * Adds the key/value pairs of a url-encoded request body to the request's
   * parameters, decoding values with the request's charset.
   *
   * @param httpRequest    request whose parameter map is extended
   * @param firstLine      parsed request line (unused here)
   * @param bodyBytes      raw body bytes (unused here)
   * @param bodyString     body decoded to text; blank input adds nothing
   * @param channelContext channel the request arrived on, used for logging
   */
  private static void parseUrlencoded(HttpRequest httpRequest, RequestLine firstLine, byte[] bodyBytes, String bodyString, ChannelContext channelContext) {
    Map<String, Object[]> requestParams = httpRequest.getParams();
    String bodyCharset = httpRequest.getCharset();
    decodeParams(requestParams, bodyString, bodyCharset, channelContext);
  }
}
